import requests
import functools
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
import os
import tiktoken
from datetime import datetime, timedelta
import subprocess
import re


from core import setting, mgClient
import pymongo as mg
import json 
import websocket
import random
# MongoDB handles for the "chat" database (mgClient is imported from core).
db = mgClient['chat']
# NOTE(review): "traiffic" is a typo for "traffic", but it is the live
# collection name on disk — renaming it would orphan existing data.
traffic_collection = db.get_collection('traiffic')  # per-IP rate-limit records (used by check_traffic)
traffic_collection_4 = db.get_collection('traiffic_4')  # per-IP rate-limit records (used by check_traffic_4)
traffic_collection_poe = db.get_collection('traiffic_poe')  # presumably Poe-model traffic — not used in this file
user_collection = db.get_collection('user')  # user_key -> remaining token_num balance


def check_traffic_4(ip):
    """
    Rate-limit GPT-4 requests per client IP.

    Request history is kept in ``traffic_collection_4``.  Entries older than
    24 hours are pruned before counting, so the effective limit is at most
    1 request per rolling 24-hour window.

    :param ip: client IP address to check
    :return: True when the IP is over its limit (reject the request),
             False when the request was recorded and is allowed.
    """
    now = datetime.now()
    traffic = traffic_collection_4.find_one({"ip": ip})
    if not traffic:
        # First request ever seen from this IP: record it and allow.
        traffic_collection_4.insert_one({"ip": ip, "data": [{"timestamp": now, "count": 1}]})
        return False
    # Prune entries older than 24 hours BEFORE counting.  The original code
    # counted first, so a stale entry kept an IP blocked forever.
    traffic["data"] = [t for t in traffic["data"] if now - t["timestamp"] <= timedelta(hours=24)]
    recent_count = sum(t["count"] for t in traffic["data"])
    if recent_count >= 1:  # limit: 1 request per IP per 24 hours
        # Persist the pruned history even on the reject path so the window
        # eventually rolls over.
        traffic_collection_4.update_one({"ip": ip}, {"$set": traffic})
        return True
    traffic["data"].append({"timestamp": now, "count": 1})
    traffic_collection_4.update_one({"ip": ip}, {"$set": traffic})
    return False
def consume_token(user_key, token_num):
    """
    Deduct ``token_num`` tokens from a user's stored balance.

    Unknown user keys are silently ignored.  No floor is applied, so the
    balance may go negative.

    :param user_key: key identifying the user record
    :param token_num: number of tokens to subtract
    """
    record = user_collection.find_one({"user_key": user_key})
    if record is None:
        return
    remaining = record["token_num"] - token_num
    user_collection.update_one({"user_key": user_key}, {"$set": {"token_num": remaining}})

def check_user_token(user_key):
    """
    Look up a user's remaining token balance.

    :param user_key: key identifying the user record
    :return: the stored token count, or 0 when the user is unknown
    """
    record = user_collection.find_one({"user_key": user_key})
    return record["token_num"] if record else 0

def add_user(user_key, token_num):
    """
    Create a user record or reset an existing user's token balance.

    Uses a single upserted update instead of the original find-then-write
    pair, which was racy: two concurrent calls for a new key could insert
    duplicate documents.

    :param user_key: key identifying the user record
    :param token_num: token balance to store
    """
    user_collection.update_one(
        {"user_key": user_key},
        {"$set": {"token_num": token_num}},
        upsert=True,  # insert {"user_key": ..., "token_num": ...} when absent
    )
        
def generate_max():
    """
    Yield the single "input too long" error message shown when a prompt
    exceeds the GPT-4 8K context window.
    """
    yield "您的输入长度过长，请您检查是否打开了【弃用记忆优化】按钮，如打开，请关闭后再试一次，4.0只支持8K的上下文，您也可选择128K模型重新尝试。或者直接请建会话重试"

async def async_request_4(prompt: str, url="https://api.openai.com/v1/chat/completions", key_value="Bearer yourkey", model="gpt-4-1106-preview"):
    """
    Fire a streaming GPT-4 chat-completion request without blocking the
    event loop (the blocking ``requests.post`` runs in a thread pool).

    :param prompt: list of chat messages to send as the "messages" payload
    :param url: chat-completions endpoint
    :param key_value: ignored — a key is always drawn at random from the
        hard-coded pool below.  NOTE(review): confirm this is intended.
    :param model: model name to request
    :return: the streaming ``requests.Response``; on any request failure a
        blank Response with status_code 500 is returned instead.
    """
    key_values = [
        "yourkey"
    ]
    key_value = random.choice(key_values)
    data = {
        "model": model,
        "messages": prompt,
        "temperature": 1,
        # "max_tokens": 4096,
        "stream": True,  # Enable streaming API
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": key_value,
    }
    loop = asyncio.get_running_loop()
    request = functools.partial(requests.post, url, headers=headers, json=data, stream=True, timeout=30)
    try:
        with ThreadPoolExecutor() as executor:
            response = await loop.run_in_executor(executor, request)
    except Exception:
        # Bug fix: the original printed response.status_code here, but
        # `response` is unbound when the request itself fails, so the
        # handler raised NameError instead of returning the 500 fallback.
        response = requests.Response()
        response.status_code = 500
    return response
def check_traffic(ip):
    """
    Rate-limit requests per client IP.

    Request history is kept in ``traffic_collection``.  Entries older than
    24 hours are pruned before counting, so the effective limit is at most
    400 requests per rolling 24-hour window.

    :param ip: client IP address to check
    :return: True when the IP is over its limit (reject the request),
             False when the request was recorded and is allowed.
    """
    now = datetime.now()
    traffic = traffic_collection.find_one({"ip": ip})
    if not traffic:
        # First request ever seen from this IP: record it and allow.
        traffic_collection.insert_one({"ip": ip, "data": [{"timestamp": now, "count": 1}]})
        return False
    # Prune entries older than 24 hours BEFORE counting.  The original code
    # counted first, so stale entries kept counting against the limit and
    # the pruned list was never saved on the reject path.
    traffic["data"] = [t for t in traffic["data"] if now - t["timestamp"] <= timedelta(hours=24)]
    recent_count = sum(t["count"] for t in traffic["data"])
    if recent_count >= 400:  # limit: 400 requests per IP per 24 hours
        # Persist the pruned history even when rejecting so the window rolls over.
        traffic_collection.update_one({"ip": ip}, {"$set": traffic})
        return True
    traffic["data"].append({"timestamp": now, "count": 1})
    traffic_collection.update_one({"ip": ip}, {"$set": traffic})
    return False

def generate_error_user():
    """
    Yield the single error message shown when a user's token balance is
    exhausted.
    """
    yield "没钱啦！"

def num_tokens_from_message(messages, model="gpt-3.5-turbo-0301"):
    """
    Estimate the number of tokens a list of chat messages consumes.

    :param messages: list of message dicts (keys such as "role", "name",
        "content"; all values are encoded)
    :param model: model whose tokenizer is used; only "gpt-3.5-turbo-0301"
        has a counting rule implemented here
    :return: estimated token count
    :raises NotImplementedError: for any model other than gpt-3.5-turbo-0301
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model name: fall back to the cl100k_base tokenizer.
        encoding = tiktoken.get_encoding("cl100k_base")
    if model != "gpt-3.5-turbo-0301":  # note: future models may deviate from this
        raise NotImplementedError(f"""num_tokens_from_messages() is not presently implemented for model {model}.
See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""")
    # Every message is framed as <im_start>{role/name}\n{content}<im_end>\n
    # (4 tokens), and the reply is primed with <im_start>assistant (2 tokens).
    total = 2
    for message in messages:
        total += 4
        for field, text in message.items():
            total += len(encoding.encode(text))
            if field == "name":
                total -= 1  # when "name" is present the role token is omitted
    return total

    
def get_api_key(api_keys, api_key_index):
    """
    Return the key at ``api_key_index`` plus the next round-robin index.

    The stray debug ``print(api_key_index)`` from the original has been
    removed.

    :param api_keys: non-empty list of API keys
    :param api_key_index: current position in the rotation
    :return: tuple ``(api_key, next_index)`` where ``next_index`` wraps
        back to 0 after the last key
    """
    api_key = api_keys[api_key_index]
    next_index = (api_key_index + 1) % len(api_keys)
    return api_key, next_index

def get_api_keys(key_file_path):
    """
    Load API keys from a text file, one key per line.

    :param key_file_path: path of the file containing the API keys
    :return: list of keys with surrounding whitespace/newlines stripped
    """
    with open(key_file_path, 'r') as key_file:
        return [line.strip() for line in key_file]

def generate_error_key():
    """
    Yield the single error message telling an unregistered IP how to apply
    for a key for continued free GPT-4 access.
    """
    yield "每个IP每24小时只能使用1次4.0，您必须申请一个有效key才能继续免费使用本服务，具体方式请访问微信链接：https://mp.weixin.qq.com/s/1ZbctH6Iaa7LZkguW6ZnVA"
def generate_error_traffic():
    """
    Yield the single rate-limit error message shown when an IP has made too
    many requests.
    """
    yield "我们聊的太多了，休息一会儿再来吧。"
def generate_error():
    """
    Yield the single generic "something went wrong, please retry" error
    message.
    """
    yield "出了点小状况，大概率是开发者太穷租不起大流量，或者调用的服务有问题，可以再试一次吗？"
def generate_error_max():
    """
    Yield the single error message shown when a conversation has grown too
    long for the current model.
    """
    yield "对话内容太多了，您最好新开一个会话再试一次或者换用其他模型。"


async def async_request(prompt: str, api_key: str, model="gpt-3.5-turbo-0613", url="https://api.openai.com/v1/chat/completions"):
    """
    Fire a streaming chat-completion request without blocking the event
    loop (the blocking ``requests.post`` runs in a thread pool).

    :param prompt: list of chat messages sent as the "messages" payload
    :param api_key: raw API key; "Bearer " is prepended here
    :param model: model name to request
    :param url: chat-completions endpoint
    :return: the streaming ``requests.Response``; on any request failure a
        blank Response with status_code 500 is returned instead.
    """
    data = {
        "model": model,
        "messages": prompt,
        "temperature": 1,
      #  "max_tokens": max_tokens,
        "stream": True,  # Enable streaming API
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
    }
    loop = asyncio.get_running_loop()
    request = functools.partial(requests.post, url, headers=headers, json=data, stream=True, timeout=20)
    try:
        with ThreadPoolExecutor() as executor:
            response = await loop.run_in_executor(executor, request)
    except Exception:
        # Bug fix: the original printed response.text here, but `response`
        # is unbound when the request itself fails, so the handler raised
        # NameError instead of returning the 500 fallback.
        response = requests.Response()
        response.status_code = 500
    return response
def generate_text(response, save_file_path=None, save_flag=0, user_key=None):
    """
    Stream assistant text out of an OpenAI server-sent-events response.

    Iterates the streaming HTTP response, extracts each delta's "content"
    field and yields it as it arrives.  After the stream ends, the full
    transcript is optionally written to ``save_file_path`` and, when a
    ``user_key`` is given, the generated tokens are deducted from that
    user's balance via ``consume_token``.

    :param response: streaming ``requests`` response from the chat API
    :param save_file_path: path to save the transcript to when ``save_flag``
    :param save_flag: truthy to write the transcript to disk
    :param user_key: when given, bill this user for the generated tokens
    """
    pieces = []
    for event in response.iter_content(chunk_size=None):
        payload = event.decode('utf-8')
        # SSE frames are prefixed with "data: "; fragments that are not
        # JSON (empty strings, the "[DONE]" sentinel) are skipped.
        for chunk in payload.split('data: '):  # renamed from `str`, which shadowed the builtin
            try:
                parsed = json.loads(chunk)
            except ValueError:
                continue
            choices = parsed.get("choices")
            if not choices:
                continue
            text = choices[0].get("delta", {}).get("content")
            if text:
                pieces.append(text)
                yield text
    # NOTE(review): deltas are joined with '\n' as in the original — since
    # streamed deltas are partial-word tokens, ''.join may be intended; kept
    # for compatibility because the token billing below depends on it.
    all_text = '\n'.join(pieces)
    if save_flag:
        # Bug fix: the original opened the file twice ('w' then 'a') and
        # wrote '\n'.join(all_text) — joining the CHARACTERS of an
        # already-joined string, which put every character on its own line.
        with open(save_file_path, 'w') as f:
            f.write(all_text)
            f.write("\'}\n")
    if user_key:
        # Bill the user for the tokens in the generated reply.
        # (Stray debug prints from the original removed.)
        message = [{"role": "assistant", "content": all_text}]
        consume_token(user_key, num_tokens_from_message(message))